titaiwangms (Contributor) commented Oct 31, 2025

This pull request introduces significant improvements and expanded support for multi-head attention kernels in ONNX Runtime, focusing on support for both 3D (BSNH) and 4D (BNSH) QKV input formats. The changes improve the flexibility, correctness, and maintainability of attention operations across the CPU and CUDA implementations.

Expanded QKV Input Format Support

  • Added support for the 4D QKV input format (Q_K_V_BNSH) in the CUDA attention kernels, covering the cases with and without past/present state and rejecting bias, which is not supported for this format. The new path avoids unnecessary transposes and writes outputs directly when possible; see the layout sketch below. [1] [2] [3] [4] [5] [6] [7]
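To make the two layouts concrete, here is a minimal sketch (plain C++, not ONNX Runtime code; the helper names are made up for illustration) of the index mapping for BSNH versus BNSH buffers and the transpose a BSNH input would otherwise need. A 4D BNSH input already matches the per-head layout the attention math consumes, which is what lets the kernel skip the copy.

```cpp
// Minimal sketch, not ORT code: flat offsets for the two QKV layouts and the
// BSNH -> BNSH transpose that a BSNH input would normally require.
#include <cstddef>

// Offset of element (b, s, n, h) in a BSNH buffer: [batch, seq, heads, head_size].
inline size_t bsnh_offset(size_t b, size_t s, size_t n, size_t h,
                          size_t S, size_t N, size_t H) {
  return ((b * S + s) * N + n) * H + h;
}

// Offset of the same element in a BNSH buffer: [batch, heads, seq, head_size].
inline size_t bnsh_offset(size_t b, size_t s, size_t n, size_t h,
                          size_t S, size_t N, size_t H) {
  return ((b * N + n) * S + s) * H + h;
}

// Copy a BSNH buffer into BNSH order; a 4D BNSH input can skip this copy.
void transpose_bsnh_to_bnsh(const float* src, float* dst,
                            size_t B, size_t S, size_t N, size_t H) {
  for (size_t b = 0; b < B; ++b)
    for (size_t s = 0; s < S; ++s)
      for (size_t n = 0; n < N; ++n)
        for (size_t h = 0; h < H; ++h)
          dst[bnsh_offset(b, s, n, h, S, N, H)] =
              src[bsnh_offset(b, s, n, h, S, N, H)];
}
```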

Kernel and Operator Documentation Updates

  • Updated OperatorKernels.md to document the new Attention operator inputs and outputs for both 3D and 4D formats, specifying supported tensor types for each input.

Correctness and Consistency Fixes

  • Fixed the computation of causal attention indices in the CUDA softmax kernels by correcting the offset used for causal masking; a worked example follows this list. [1] [2] [3] [4]
  • Updated workspace allocation logic for QKV preparation to ensure correct workspace usage for new formats.
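As a rough illustration of what the causal offset has to account for (a standalone sketch under assumed variable names, not the actual CUDA softmax kernel): with a past key/value cache, query row i of the current chunk sits at absolute position past_len + i, so a causal mask must admit keys 0 through past_len + i and no more. Getting this offset wrong either leaks future tokens or masks out valid past tokens.

```cpp
// Standalone sketch of causal masking with a KV cache (illustrative only).
#include <cstdio>

// Number of key positions query row `query_index` may attend to under a
// causal mask when `past_sequence_length` tokens are already cached.
inline int causal_valid_length(int query_index, int past_sequence_length) {
  return past_sequence_length + query_index + 1;  // keys [0, past + i]
}

int main() {
  const int past_len = 4;  // tokens already in the KV cache
  const int seq_len = 3;   // tokens in the current chunk
  const int total_len = past_len + seq_len;
  for (int i = 0; i < seq_len; ++i) {
    const int valid = causal_valid_length(i, past_len);
    std::printf("query %d attends to keys [0, %d); %d future key(s) masked\n",
                i, valid, total_len - valid);
  }
  return 0;
}
```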

Attention Parameter and Helper Refactoring

  • Added an is_output_bnsh field to AttentionParameters to indicate the output format, and updated the output placement and transposition logic to use it (see the sketch after this list). [1] [2]
  • Refactored CPU attention implementation to use the new attention_helper namespace for output mode enums and output shape computation, improving code clarity and maintainability. [1] [2] [3]
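A minimal sketch of how such a flag can drive the output path (hypothetical type and field names, not the actual AttentionParameters): the score-times-V product is produced per head, i.e. in BNSH order, so when the requested output layout is also BNSH the kernel can write into the output buffer directly and skip the final transpose and its scratch buffer.

```cpp
// Hypothetical parameter struct, for illustration only.
struct AttentionParams {
  int batch_size = 0;
  int sequence_length = 0;
  int num_heads = 0;
  int head_size = 0;
  bool is_output_bnsh = false;  // true: caller wants BNSH output, no final transpose
};

// The attention result is computed per head (BNSH); if that is also the
// requested output layout, it can be written to the output buffer directly.
inline bool can_write_output_directly(const AttentionParams& p) {
  return p.is_output_bnsh;
}
```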

Minor Cleanups

  • Removed outdated asserts and improved debug output strings for QKV preparation functions to clarify format and state handling. [1] [2] [3]

These changes collectively improve the flexibility, correctness, and maintainability of attention kernel implementations in ONNX Runtime, especially for advanced transformer models and large language model workloads.

NOT supported in this PR

  • Boolean mask
  • GQA
  • Softcap
  • Softmax precision
  • qk_output_mode other than -1 and 0

@titaiwangms added the ep:CUDA label (issues related to the CUDA execution provider) on Nov 19, 2025
@titaiwangms requested a review from @tianleiwu on January 12, 2026 17:53
@titaiwangms modified the milestone: 1.24.0 on Jan 13, 2026
@tianleiwu changed the title from Attenion(23) CUDA to Attention(23) CUDA on Jan 14, 2026
@tianleiwu previously approved these changes on Jan 14, 2026
@titaiwangms enabled auto-merge (squash) on January 14, 2026 21:25
@titaiwangms merged commit a3e477e into main on Jan 15, 2026
88 of 90 checks passed
@titaiwangms deleted the titaiwang/support_attention_cuda branch on January 15, 2026 02:57
titaiwangms added a commit that referenced this pull request Jan 15, 2026
titaiwangms added a commit that referenced this pull request Jan 15, 2026